Re-weighting#
PolicyEngine-UK primarily relies on the Family Resources Survey, which has known issues with non-capture of households at the bottom and top of the income distribution. To correct for this, we apply a weight modification, optimised using gradient descent to minimise survey error against a diverse selection of targeting statistics. These include:
Regional populations
Household populations
Population by tenure type
Population by Council Tax band
Country-level program statistics
UK-wide program aggregates
UK-wide program caseloads
The graph below shows the effect of the optimisation on each of these, compared to their starting values (under original FRS weights). All loss subfunctions improve from their starting values.
Show code cell source
# Load the category-level training log for the reweighting optimisation
# (run with no train/validation split) and plot, per target category, how
# total loss evolves relative to its value under the original FRS weights.
# NOTE: the original cell imported pandas and plotly.express twice; the
# duplicates are removed here.
import pandas as pd
import numpy as np
import plotly.express as px

df = pd.read_csv(
    "https://github.com/PolicyEngine/openfisca-uk-reweighting/raw/master/no_val_split/training_log_run_1.csv.gz",
)

# Sum loss within each (category, epoch), then pivot so that each target
# category becomes its own column indexed by epoch.
ldf = (
    df.groupby(["category", "epoch"])
    .sum()
    .reset_index()
    .pivot(columns="category", values="loss", index="epoch")
)
# Express each category's loss relative to epoch 0:
# 0 = unchanged, -0.5 = halved, +0.5 = risen by 50%.
ldf /= ldf.loc[0]
ldf -= 1
# Long format (epoch, category, value) for plotly express.
ldf = ldf.reset_index().melt(id_vars=["epoch"])
ldf["hover"] = [
    f"At epoch {epoch}, the total loss from targets <br>in the category <b>{category}</b> <br>has <b>{'risen' if value > 0 else 'fallen'}</b> by <b>{abs(value):.1%}</b>."
    for epoch, category, value in zip(ldf.epoch, ldf.category, ldf.value)
]
px.line(
    ldf, x="epoch", y="value", color="category", custom_data=[ldf.hover]
).update_traces(hovertemplate="%{customdata[0]}").update_layout(
    title="Training performance by category",
    height=600,
    width=800,
    xaxis_title="Epoch",
    yaxis_title="Loss change",
    legend_title="Category",
    yaxis_range=(-1, 0),
    yaxis_tickformat=".0%",
)
Changes to distributions#
Validation#
During initial training, we split the targets into training and validation groups (80%/20%), performing 5-fold cross-validation. The graph below shows the performance of validation metrics in each fold, as well as the average over the five folds.
Show code cell source
# Load the training log from the run with an 80%/20% train/validation
# target split, then chart the relative loss change per fold (run_id) for
# three subsets of targets: validation-only, training-only, and both.
df = pd.read_csv(
    "https://github.com/PolicyEngine/openfisca-uk-reweighting/raw/master/train_val_split/training_log.csv.gz",
    compression="gzip",
)

frames = []
for subset in (True, False, "Both"):
    if subset == "Both":
        # Tautology selecting every row, matching the original cell.
        mask = df.validation | ~df.validation
    else:
        mask = df.validation == subset
    # Total loss per (fold, epoch), pivoted so each fold is a column.
    fold_losses = (
        df[mask]
        .groupby(["run_id", "epoch"])
        .loss.sum()
        .reset_index()
        .pivot(columns="run_id", values="loss", index="epoch")
    )
    # Normalise against epoch 0 so values read as a percentage change,
    # dropping epochs with missing data.
    fold_losses = (fold_losses / fold_losses.loc[0] - 1).dropna()
    # Cross-fold average, computed before the Type label is attached.
    fold_losses["Average"] = fold_losses.mean(axis=1)
    fold_losses["Type"] = {
        True: "Validation",
        False: "Training",
        "Both": "Training + Validation",
    }[subset]
    frames.append(fold_losses)
xdf = pd.concat([pd.DataFrame()] + frames)

px.line(
    xdf,
    y=xdf.columns,
    animation_frame="Type",
    color_discrete_sequence=["lightgrey"] * 5 + ["grey"],
).update_layout(
    title="5-fold cross-validation training",
    yaxis_title="Relative loss change",
    yaxis_tickformat=".0%",
    xaxis_title="Epoch",
    legend_title="Fold",
    width=800,
    height=800,
)
---------------------------------------------------------------------------
KeyboardInterrupt Traceback (most recent call last)
Cell In[2], line 1
----> 1 df = pd.read_csv(
2 "https://github.com/PolicyEngine/openfisca-uk-reweighting/raw/master/train_val_split/training_log.csv.gz",
3 compression="gzip",
4 )
5 xdf = pd.DataFrame()
6 for validation_type in (True, False, "Both"):
File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/site-packages/pandas/io/parsers/readers.py:912, in read_csv(filepath_or_buffer, sep, delimiter, header, names, index_col, usecols, dtype, engine, converters, true_values, false_values, skipinitialspace, skiprows, skipfooter, nrows, na_values, keep_default_na, na_filter, verbose, skip_blank_lines, parse_dates, infer_datetime_format, keep_date_col, date_parser, date_format, dayfirst, cache_dates, iterator, chunksize, compression, thousands, decimal, lineterminator, quotechar, quoting, doublequote, escapechar, comment, encoding, encoding_errors, dialect, on_bad_lines, delim_whitespace, low_memory, memory_map, float_precision, storage_options, dtype_backend)
899 kwds_defaults = _refine_defaults_read(
900 dialect,
901 delimiter,
(...)
908 dtype_backend=dtype_backend,
909 )
910 kwds.update(kwds_defaults)
--> 912 return _read(filepath_or_buffer, kwds)
File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/site-packages/pandas/io/parsers/readers.py:577, in _read(filepath_or_buffer, kwds)
574 _validate_names(kwds.get("names", None))
576 # Create the parser.
--> 577 parser = TextFileReader(filepath_or_buffer, **kwds)
579 if chunksize or iterator:
580 return parser
File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/site-packages/pandas/io/parsers/readers.py:1407, in TextFileReader.__init__(self, f, engine, **kwds)
1404 self.options["has_index_names"] = kwds["has_index_names"]
1406 self.handles: IOHandles | None = None
-> 1407 self._engine = self._make_engine(f, self.engine)
File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/site-packages/pandas/io/parsers/readers.py:1661, in TextFileReader._make_engine(self, f, engine)
1659 if "b" not in mode:
1660 mode += "b"
-> 1661 self.handles = get_handle(
1662 f,
1663 mode,
1664 encoding=self.options.get("encoding", None),
1665 compression=self.options.get("compression", None),
1666 memory_map=self.options.get("memory_map", False),
1667 is_text=is_text,
1668 errors=self.options.get("encoding_errors", "strict"),
1669 storage_options=self.options.get("storage_options", None),
1670 )
1671 assert self.handles is not None
1672 f = self.handles.handle
File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/site-packages/pandas/io/common.py:716, in get_handle(path_or_buf, mode, encoding, compression, memory_map, is_text, errors, storage_options)
713 codecs.lookup_error(errors)
715 # open URLs
--> 716 ioargs = _get_filepath_or_buffer(
717 path_or_buf,
718 encoding=encoding,
719 compression=compression,
720 mode=mode,
721 storage_options=storage_options,
722 )
724 handle = ioargs.filepath_or_buffer
725 handles: list[BaseBuffer]
File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/site-packages/pandas/io/common.py:368, in _get_filepath_or_buffer(filepath_or_buffer, encoding, compression, mode, storage_options)
366 # assuming storage_options is to be interpreted as headers
367 req_info = urllib.request.Request(filepath_or_buffer, headers=storage_options)
--> 368 with urlopen(req_info) as req:
369 content_encoding = req.headers.get("Content-Encoding", None)
370 if content_encoding == "gzip":
371 # Override compression based on Content-Encoding header
File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/site-packages/pandas/io/common.py:270, in urlopen(*args, **kwargs)
264 """
265 Lazy-import wrapper for stdlib urlopen, as that imports a big chunk of
266 the stdlib.
267 """
268 import urllib.request
--> 270 return urllib.request.urlopen(*args, **kwargs)
File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/urllib/request.py:214, in urlopen(url, data, timeout, cafile, capath, cadefault, context)
212 else:
213 opener = _opener
--> 214 return opener.open(url, data, timeout)
File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/urllib/request.py:523, in OpenerDirector.open(self, fullurl, data, timeout)
521 for processor in self.process_response.get(protocol, []):
522 meth = getattr(processor, meth_name)
--> 523 response = meth(req, response)
525 return response
File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/urllib/request.py:632, in HTTPErrorProcessor.http_response(self, request, response)
629 # According to RFC 2616, "2xx" code indicates that the client's
630 # request was successfully received, understood, and accepted.
631 if not (200 <= code < 300):
--> 632 response = self.parent.error(
633 'http', request, response, code, msg, hdrs)
635 return response
File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/urllib/request.py:555, in OpenerDirector.error(self, proto, *args)
553 http_err = 0
554 args = (dict, proto, meth_name) + args
--> 555 result = self._call_chain(*args)
556 if result:
557 return result
File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/urllib/request.py:494, in OpenerDirector._call_chain(self, chain, kind, meth_name, *args)
492 for handler in handlers:
493 func = getattr(handler, meth_name)
--> 494 result = func(*args)
495 if result is not None:
496 return result
File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/urllib/request.py:747, in HTTPRedirectHandler.http_error_302(self, req, fp, code, msg, headers)
744 fp.read()
745 fp.close()
--> 747 return self.parent.open(new, timeout=req.timeout)
File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/urllib/request.py:517, in OpenerDirector.open(self, fullurl, data, timeout)
514 req = meth(req)
516 sys.audit('urllib.Request', req.full_url, req.data, req.headers, req.get_method())
--> 517 response = self._open(req, data)
519 # post-process response
520 meth_name = protocol+"_response"
File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/urllib/request.py:534, in OpenerDirector._open(self, req, data)
531 return result
533 protocol = req.type
--> 534 result = self._call_chain(self.handle_open, protocol, protocol +
535 '_open', req)
536 if result:
537 return result
File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/urllib/request.py:494, in OpenerDirector._call_chain(self, chain, kind, meth_name, *args)
492 for handler in handlers:
493 func = getattr(handler, meth_name)
--> 494 result = func(*args)
495 if result is not None:
496 return result
File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/urllib/request.py:1389, in HTTPSHandler.https_open(self, req)
1388 def https_open(self, req):
-> 1389 return self.do_open(http.client.HTTPSConnection, req,
1390 context=self._context, check_hostname=self._check_hostname)
File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/urllib/request.py:1350, in AbstractHTTPHandler.do_open(self, http_class, req, **http_conn_args)
1348 except OSError as err: # timeout error
1349 raise URLError(err)
-> 1350 r = h.getresponse()
1351 except:
1352 h.close()
File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/http/client.py:1377, in HTTPConnection.getresponse(self)
1375 try:
1376 try:
-> 1377 response.begin()
1378 except ConnectionError:
1379 self.close()
File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/http/client.py:320, in HTTPResponse.begin(self)
318 # read until we get a non-100 response
319 while True:
--> 320 version, status, reason = self._read_status()
321 if status != CONTINUE:
322 break
File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/http/client.py:281, in HTTPResponse._read_status(self)
280 def _read_status(self):
--> 281 line = str(self.fp.readline(_MAXLINE + 1), "iso-8859-1")
282 if len(line) > _MAXLINE:
283 raise LineTooLong("status line")
File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/socket.py:704, in SocketIO.readinto(self, b)
702 while True:
703 try:
--> 704 return self._sock.recv_into(b)
705 except timeout:
706 self._timeout_occurred = True
File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/ssl.py:1242, in SSLSocket.recv_into(self, buffer, nbytes, flags)
1238 if flags != 0:
1239 raise ValueError(
1240 "non-zero flags not allowed in calls to recv_into() on %s" %
1241 self.__class__)
-> 1242 return self.read(nbytes, buffer)
1243 else:
1244 return super().recv_into(buffer, nbytes, flags)
File /opt/hostedtoolcache/Python/3.9.17/x64/lib/python3.9/ssl.py:1100, in SSLSocket.read(self, len, buffer)
1098 try:
1099 if buffer is not None:
-> 1100 return self._sslobj.read(len, buffer)
1101 else:
1102 return self._sslobj.read(len)
KeyboardInterrupt:
The chart below visualises the effect of the training process on each individual training and validation metric, by epoch.
Show code cell source
# Per-target relative error between predicted and actual values, plus a
# readable label for the train/validation flag. (Mutates df in place, as
# hover_data below reads df.columns.)
df["rel_error"] = df.pred / df.actual - 1
df["Type"] = np.where(df.validation, "Validation", "Training")

# Sample every 50th epoch and keep only the two aggregate categories.
STEP_SIZE = 50
sampled = df[df.epoch % STEP_SIZE == 0]
sampled = sampled[
    sampled.category.isin(["Budgetary impact", "UK-wide program aggregates"])
]

# Animated scatter with one frame per sampled epoch.
fig = px.scatter(
    sampled,
    animation_frame="epoch",
    x="actual",
    y="rel_error",
    color="Type",
    hover_data=df.columns,
    opacity=0.2,
)
layout = dict(
    title="Target metrics",
    width=800,
    height=800,
    legend_title="Type",
    yaxis_title="Relative error",
    yaxis_tickformat=".1%",
    xaxis_tickprefix="£",
    xaxis_title="Actual value",
    yaxis_range=(-1, 1),
)
fig.update_layout(**layout)

# Re-apply the layout to each animation frame and title it with the epoch
# number that frame shows.
for frame_index, frame in enumerate(fig.frames):
    frame.layout.update(layout)
    frame.layout[
        "title"
    ] = f"Budgetary impact target metric performance at {frame_index * STEP_SIZE:,} epochs"

# Force a full redraw on every slider step and play/pause button press so
# the per-frame layouts above actually take effect.
for step in fig.layout.sliders[0].steps:
    step["args"][1]["frame"]["redraw"] = True
for button in fig.layout.updatemenus[0].buttons:
    button["args"][1]["frame"]["redraw"] = True

# Export the animation as a GIF, one image per animation frame.
import gif
import plotly.graph_objects as go

gif.save(
    [
        gif.frame(lambda: go.Figure(data=frame.data, layout=frame.layout))()
        for frame in fig.frames
    ],
    "scatterplot.gif",
    duration=3_000 / len(fig.frames),
)
fig